import asyncio
import argparse
import json
import os
import re
import time
import traceback
import aiohttp  
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
from collections import defaultdict
from datetime import datetime

@dataclass
class NodeConfig:
    """Search/runtime settings held by a reasoner (see ``BaseReasoner.config``).

    Both settings are plain integers with defaults. How they are consumed is
    not visible in this part of the file. Presumably reasoner subclasses read
    them; verify against callers.
    """

    # Number of candidates kept per expansion (inferred from the name; confirm with callers).
    beam_width: int = field(default=3)
    # Upper bound on retry attempts (presumably for LLM calls; confirm with callers).
    max_retries: int = field(default=3)

class LLMClient:
    """Async client for an OpenAI-style ``/chat/completions`` HTTP endpoint.

    ``token_counts`` is a two-element list ``[input_tokens, output_tokens]``.
    It accumulates across every successful :meth:`generate` call on this
    instance.
    """

    def __init__(
        self,
        model: str = "your model",
        base_url: str = "your api",
        timeout: float = 120,
        api_key: Optional[str] = None,
    ):
        """Configure the client.

        Args:
            model: Model name sent with every request (placeholder by default).
            base_url: API root; ``/chat/completions`` is appended to it.
            timeout: Total per-request timeout in seconds.
            api_key: Optional bearer token. When set, it is sent as an
                ``Authorization: Bearer <api_key>`` header; when ``None``,
                no auth header is sent (the original behavior).
        """
        self.model = model
        self.base_url = base_url
        self.timeout = timeout
        self.api_key = api_key
        self.token_counts = [0, 0]  # [input tokens, output tokens], cumulative

    def _build_payload(self, prompt: str, response_format: str = "text") -> Dict[str, Any]:
        """Return the JSON request body for a single-turn chat completion."""
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "max_tokens": 8192,
            "top_p": 0.9,
        }
        # Only "json_object" maps to an API option. Any other value (including
        # the default "text") leaves the server's default output format.
        if response_format == "json_object":
            payload["response_format"] = {"type": "json_object"}
        return payload

    def _parse_response(self, prompt: str, resp: Dict[str, Any]) -> str:
        """Return ``choices[0].message.content`` from ``resp`` and update ``token_counts``.

        Token counts come from the response's ``usage.prompt_tokens`` /
        ``usage.completion_tokens`` when present as ints. Otherwise they fall
        back to a rough ``len(text) // 4`` characters-per-token estimate.

        Raises:
            KeyError, IndexError or TypeError: if ``resp`` does not have the
                expected ``choices[0].message.content`` structure.
        """
        content = resp["choices"][0]["message"]["content"]

        usage = resp.get("usage")
        if not isinstance(usage, dict):
            usage = {}
        input_tokens = usage.get("prompt_tokens")
        output_tokens = usage.get("completion_tokens")
        if not isinstance(input_tokens, int):
            input_tokens = len(prompt) // 4
        if not isinstance(output_tokens, int):
            output_tokens = len(content) // 4

        self.token_counts[0] += input_tokens
        self.token_counts[1] += output_tokens
        return content

    async def generate(self, prompt: str, response_format: str = "text") -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The user message content.
            response_format: ``"json_object"`` requests JSON mode; any other
                value sends no ``response_format`` option.

        Returns:
            The ``choices[0].message.content`` string from the response.

        Raises:
            Any exception from the request or from parsing the response; it is
            printed and then re-raised. A non-2xx HTTP status raises
            ``aiohttp.ClientResponseError`` immediately. Without this check it
            would surface later as a confusing ``KeyError('choices')``.
        """
        try:
            payload = self._build_payload(prompt, response_format)
            headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
            url = f"{self.base_url.rstrip('/')}/chat/completions"

            async with aiohttp.ClientSession(headers=headers) as session:
                async with session.post(
                    url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.timeout),
                ) as response:
                    response.raise_for_status()
                    resp = await response.json()

            return self._parse_response(prompt, resp)
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

@dataclass
class ReasoningNode:
    """A single node in the structure of reasoning states built by a reasoner.

    Only ``node_id`` is required. Every other field has a default, and each
    mutable default (list/dict) is created fresh per instance via
    ``default_factory``. Field order defines the positional constructor
    signature, so do not reorder fields.
    """

    node_id: str  # BaseReasoner._create_node assigns ids of the form "N<k>"
    path: List[str] = field(default_factory=list)
    method: Dict[str, Any] = field(default_factory=dict)
    steps: List[str] = field(default_factory=list)
    score: int = field(default=0)
    constraints: Dict[str, Any] = field(default_factory=dict)
    # Starts as "pending"; other state values are assigned outside this chunk.
    state: str = field(default="pending")
    answer: Optional[str] = field(default=None)
    parent_id: Optional[str] = field(default=None)  # presumably another node's node_id
    children: List[str] = field(default_factory=list)  # presumably child node_ids
    question: Optional[str] = field(default=None)
    theory: Optional[Dict[str, Any]] = field(default=None)
    facts: List[str] = field(default_factory=list)
    options: Dict[str, str] = field(default_factory=dict)

class BaseReasoner:
    """Shared state and bookkeeping for dataset-specific reasoners.

    Holds an LLM client, a NodeConfig, a table of ReasoningNodes keyed by id,
    a list of step-log records, and running accuracy statistics.
    """

    def __init__(self, dataset_name: str):
        """Initialize empty reasoning state for ``dataset_name``."""
        self.llm = LLMClient()
        self.config = NodeConfig()
        self.dataset_name = dataset_name
        self.nodes: Dict[str, ReasoningNode] = {}  # node_id -> node
        # Not used in this part of the file; presumably scratch space for subclasses.
        self.temp_list: List[str] = []
        self.current_node_id = 0  # numeric suffix for the next "N<k>" node id
        self.logs: List[Dict[str, Any]] = []
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "incorrect_answers": 0,
            "accuracy": 0.0,  # percentage (0-100) of problems answered correctly
        }

    def _create_node(self, **kwargs) -> "ReasoningNode":
        """Create, register and return a ReasoningNode with the next sequential id.

        ``kwargs`` are forwarded to ``ReasoningNode``. They must not include
        ``node_id``, which is always generated here; passing it raises
        ``TypeError``.
        """
        node_id = f"N{self.current_node_id}"
        # Build the node before advancing the counter, so that construction
        # failing on bad kwargs does not consume an id.
        node = ReasoningNode(node_id=node_id, **kwargs)
        self.current_node_id += 1
        self.nodes[node_id] = node
        return node

    def _log_step(self, step: str, node_id: str, details: Dict[str, Any]):
        """Append a ``{"step", "node_id", "details"}`` record to ``self.logs``."""
        self.logs.append({
            "step": step,
            "node_id": node_id,
            "details": details
        })

    def update_stats(self, is_correct: bool):
        """Record one graded problem and recompute accuracy.

        Args:
            is_correct: Whether the problem was answered correctly.

        ``stats["accuracy"]`` is stored as a percentage in the range 0-100.
        """
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        else:
            self.stats["incorrect_answers"] += 1

        # total_problems was just incremented, so it is >= 1 and the division is safe.
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )